package mysql

import (
	"fmt"
	"io"
	"strings"
	"text/template"

	"gitee.com/lipore/plume/db_gorm/gen/code"
	"gitee.com/lipore/plume/db_gorm/gen/spec"
)

// updateMethodData is the data passed to the update-method templates.
//
// The embedded repositoryMethodData supplies the generated method's Receiver
// and Spec (the templates read .Spec.Params). The embedded intOrBoolReturnData
// carries IntReturn, which renderUpdateMethod sets for spec.QueryModeMany
// operations.
//
// Fields:
//   - Entity: the persistence type instantiated as `entity` in the generated
//     method (`&{{typeToCode .Entity}}{}`).
//   - Query: the pre-rendered query fragment spliced directly after
//     `tx.Model(&entity)` or `tx` in the "updateStmt" templates.
//   - Update: the update description. A spec.UpdateFields value selects the
//     field-map template; any other value selects the whole-model template.
type updateMethodData struct {
	repositoryMethodData
	intOrBoolReturnData
	Entity code.Type
	Query  string
	Update spec.Update
}

// updateMethodTemplate is the skeleton of a generated repository update method.
//
// The generated code:
//   - builds an empty entity (`&T{}` for the Entity type);
//   - obtains a connection with db_gorm.LoadTx(ctx);
//   - runs the "updateStmt" block inside conn.Transaction, keeping the
//     statement's *gorm.DB in `result` for the "returnStmt" block.
//
// Errors from LoadTx and from the transaction are both rendered through the
// "errorReturnStmt" block.
//
// Every {{block}} here has an empty body. In text/template, an empty
// definition does not replace an existing one, so these act only as call sites
// for definitions registered elsewhere. "updateStmt" comes from
// updateFieldsTemplate or updateModelTemplate; the others presumably come from
// renderRepositoryMethod and renderReturnStmt (confirm).
//
// NOTE(review): the template mixes tab and space indentation (see the LoadTx
// line) and omits a space in `if err != nil{`. Presumably the generated file is
// gofmt'd afterwards; confirm.
var updateMethodTemplate = `
{{block "repositoryMethod" .}}{{end}} {
	entity := &{{typeToCode .Entity}}{}
    conn, err := db_gorm.LoadTx(ctx)
	if err != nil{
	{{- block "errorReturnStmt" .}}{{end -}}
	}
	var result *gorm.DB
	err = conn.Transaction(func(tx *gorm.DB) error {
		{{block "updateStmt" .}}{{end}}
	})
	if err != nil {
	{{- block "errorReturnStmt" .}}{{end -}}
	}
	{{block "returnStmt" .}}{{end}}
}`

// updateFieldsTemplate defines the "updateStmt" block used when the update
// operation is a spec.UpdateFields.
//
// It updates only the listed columns via
// tx.Model(&entity){{.Query}}.Updates(map[string]interface{}{...}). The map
// entries are produced by the buildUpdateStmt template function from .Update
// and .Spec.Params. The resulting *gorm.DB is stored in the enclosing method's
// `result` variable, and its Error is returned from the transaction closure.
//
// NOTE(review): `entity` is declared as a pointer (`&T{}`) in
// updateMethodTemplate, so `&entity` here is a pointer to a pointer. This
// presumably relies on GORM dereferencing nested pointers; confirm.
var updateFieldsTemplate = `{{define "updateStmt" -}}
    result = tx.Model(&entity){{.Query}}.Updates(map[string]interface{}{ {{buildUpdateStmt .Update .Spec.Params}} })
	return result.Error
{{end}}`

// buildUpdateStmt renders the contents of the map literal passed to GORM's
// Updates in updateFieldsTemplate: one `"column": paramName` entry per update
// field, joined with ", ".
//
// For each field:
//   - the key is fieldToColumn applied to the referenced field, emitted as a
//     quoted Go string literal;
//   - the value is the name of the method parameter at field.ParamIndex.
//
// An out-of-range ParamIndex panics with an index error.
func buildUpdateStmt(fields spec.UpdateFields, params []code.Param) string {
	entries := make([]string, len(fields))
	for i, field := range fields {
		column := fieldToColumn(field.FieldReference.ReferencedField())
		value := params[field.ParamIndex].Name
		// %q emits a properly escaped Go string literal. The output matches the
		// old hand-built quoting for ordinary names, but stays valid Go even if
		// a column name contains quotes or backslashes.
		entries[i] = fmt.Sprintf("%q: %s", column, value)
	}
	return strings.Join(entries, ", ")
}

// updateModelTemplate defines the "updateStmt" block used for any update that
// is not a spec.UpdateFields.
//
// It fills `entity` from the method parameter returned by the getModelName
// template function (params[1]) via entity.fromDomainObject, then calls
// tx{{.Query}}.Updates(&entity). The resulting *gorm.DB is stored in `result`,
// and its Error is returned from the transaction closure.
//
// NOTE(review): `entity` is declared as `&T{}` in updateMethodTemplate, so
// `&entity` is a pointer to a pointer. Confirm GORM handles the extra
// indirection as intended.
var updateModelTemplate = `{{define "updateStmt" -}}
	entity.fromDomainObject({{getModelName .Spec.Params}})
	result = tx{{.Query}}.Updates(&entity)
	return result.Error
{{- end}}`

// getModelName returns the name of the generated method's second parameter.
// updateModelTemplate passes that parameter to entity.fromDomainObject.
//
// The first parameter is presumably ctx, since the method skeleton calls
// db_gorm.LoadTx(ctx); confirm. Panics if params has fewer than two elements.
func getModelName(params []code.Param) string {
	const modelParamIndex = 1
	modelParam := params[modelParamIndex]
	return modelParam.Name
}

// renderUpdateMethod writes the Go source of a generated repository update
// method to writer.
//
// The method is assembled from several named templates, in this order:
//   - the method header, via renderRepositoryMethod;
//   - the "updateStmt" block: updateFieldsTemplate when operation.Update is a
//     spec.UpdateFields, otherwise updateModelTemplate;
//   - the return statements, via renderReturnStmt;
//   - the updateMethodTemplate skeleton that stitches them together.
//
// operation.Query is rendered separately (renderQueryStmt) and spliced into
// the update statement.
//
// Parameters:
//   - s: the method specification; it is exposed to the templates as .Spec,
//     and its Params feed the update-statement helpers.
//   - entity: the persistence type instantiated inside the generated method.
//   - repository: the repository type; a pointer to it is passed as the
//     method's Receiver, named "repository".
//   - operation: the parsed update operation (query, update kind and mode).
//     IntReturn is set when operation.Mode is spec.QueryModeMany.
//   - writer: destination for the rendered source.
//
// It returns the first error from building the templates, rendering the query,
// or executing the template.
func renderUpdateMethod(s spec.MethodSpec, entity code.Type, repository code.Type, operation spec.UpdateOperation, writer io.Writer) error {
	tmpl := template.New("update Method")
	tmpl, err := renderRepositoryMethod(tmpl)
	if err != nil {
		return err
	}

	// Each helper must be registered with Funcs before parsing the template
	// that calls it; text/template rejects undefined functions at parse time.
	if _, updateFieldsMode := operation.Update.(spec.UpdateFields); updateFieldsMode {
		tmpl, err = tmpl.Funcs(map[string]any{"buildUpdateStmt": buildUpdateStmt}).Parse(updateFieldsTemplate)
	} else {
		tmpl, err = tmpl.Funcs(map[string]any{"getModelName": getModelName}).Parse(updateModelTemplate)
	}
	if err != nil {
		return err
	}

	// Check this error before touching tmpl again. When err != nil, the
	// returned template cannot be relied on, and the error would otherwise be
	// overwritten by the Parse below.
	tmpl, err = renderReturnStmt(tmpl)
	if err != nil {
		return err
	}

	// The skeleton's {{block}} placeholders have empty bodies, so they do not
	// replace the definitions registered above.
	tmpl, err = tmpl.Parse(updateMethodTemplate)
	if err != nil {
		return err
	}

	query, err := renderQueryStmt(s, operation.Query)
	if err != nil {
		return err
	}
	tmplData := updateMethodData{
		repositoryMethodData: repositoryMethodData{Receiver: code.Param{Name: "repository", Type: code.PointerType{ContainedType: repository}}, Spec: s},
		intOrBoolReturnData: intOrBoolReturnData{
			IntReturn: operation.Mode == spec.QueryModeMany,
		},
		Entity: entity,
		Query:  query,
		Update: operation.Update,
	}
	return tmpl.Execute(writer, tmplData)
}
